#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# BSD 3-Clause License

# Copyright (c) 2016 ~ 2025, NXROBO Ltd.
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import copy
import threading
import rospy
import numpy as np
import math
from sagittarius_joy.msg import arm_joy
import tf.transformations as tf_transformations
import sdk_sagittarius_arm.msg
import std_msgs.msg
import modern_robotics as mr
from sensor_msgs.msg import JointState


class AngleManipulation:
    """Helpers for homogeneous transforms and planar yaw rotations."""

    @staticmethod
    def transInv(T):
        """Invert a 4x4 homogeneous transformation matrix.

        The inverse is assembled in closed form from the rotation block R
        (upper-left 3x3) and the translation column p (upper-right 3x1) as
        [[R^T, -R^T p], [0, 0, 0, 1]].

        T: 4x4 array supporting tuple slicing (e.g. a numpy array)
        return: 4x4 numpy array holding the inverse transform
        """
        rot_t = np.array(T[:3, :3]).T
        # Translation of the inverse, shaped as a column so it can be appended to R^T.
        inv_translation = -np.dot(rot_t, T[:3, 3]).reshape(3, 1)
        upper_rows = np.hstack((rot_t, inv_translation))
        return np.vstack((upper_rows, [[0, 0, 0, 1]]))

    @staticmethod
    def yawToRotationMatrix(yaw):
        """Build the 2x2 rotation matrix for a yaw angle.

        yaw: rotation angle in radians
        return: 2x2 numpy array [[cos, -sin], [sin, cos]]
        """
        cos_yaw, sin_yaw = math.cos(yaw), math.sin(yaw)
        return np.array([[cos_yaw, -sin_yaw],
                         [sin_yaw, cos_yaw]])


class SagittariusCtrlTool:
    ''' Simple control toolbox for the robot arm

    Publishes torque, joint and gripper commands, records the joint angles
    received on ``joint_states`` and wraps the ``modern_robotics``
    ``FKinSpace`` / ``IKinSpace`` functions for this arm.
    '''
    # Latest joint angles received on ``joint_states`` (first six entries).
    _real_joint_degree = [0, 0, 0, 0, 0, 0]
    # Latest joint angles accepted by set_all_joint().
    _joint_degree = [0, 0, 0, 0, 0, 0]
    # Per-joint bounds enforced by set_all_joint() and get_IK().
    lower_joint_limits = [-2, -1.57, -1.48, -2.9, -1.8, -3.1]
    upper_joint_limits = [2, 1.4, 1.8, 2.9, 1.6, 3.1]
    # Passed to FKinSpace/IKinSpace as ``Slist`` (one column per joint).
    Slist = np.array([[0.0, 0.0, -1.0, 0.0, 0.045, 0.0],
                      [0.0, -1.0, 0.0, 0.125, 0.0, -0.045],
                      [0.0, -1.0, 0.0, 0.304, 0.0, -0.078],
                      [-1.0, 0.0, 0.0, 0.0, -0.304, -0.0002],
                      [0.0, -1.0, 0.0, 0.304, 0.0, -0.24685],
                      [-1.0, 0.0, 0.0, 0.0, -0.304, 0.00045]]).T
    # Passed to FKinSpace/IKinSpace as ``M``.
    M = np.array([[1.0, 0.0, 0.0, 0.363937],
                  [0.0, 1.0, 0.0, 0.00045],
                  [0.0, 0.0, 1.0, 0.304],
                  [0.0, 0.0, 0.0, 1.0]])

    # State of the failure-log gate used by set_ee_posture().
    _time_start = False
    _time_last = 0

    def __init__(self) -> None:
        # The lock has to exist before the subscriber is registered: the
        # callback acquires it and may fire as soon as the subscription is up.
        self.joint_mutex = threading.Lock()

        # publish torque command and joint angle to arm
        self.torque_pub = rospy.Publisher(
            "control_torque", std_msgs.msg.String, queue_size=3)
        self.joint_pub = rospy.Publisher(
            "joint/commands", sdk_sagittarius_arm.msg.ArmRadControl, queue_size=3)
        self.gripper_pub = rospy.Publisher(
            "gripper/command", std_msgs.msg.Float64, queue_size=3)

        # subscribe joint states
        self.joint_sub = rospy.Subscriber(
            'joint_states', JointState, callback=self.joint_status_callback, queue_size=30)

        # Wait for joint states to arrive, then start the commanded posture
        # from whatever was last reported.
        rospy.sleep(1)
        self._joint_degree = list(self._real_joint_degree)
        self.set_gripper(0)

    def get_FK(self, joint_list):
        ''' Forward kinematics

        joint_list: joint angles passed to ``mr.FKinSpace``
        return: 4x4 end-effector matrix
        '''
        return mr.FKinSpace(self.M, self.Slist, joint_list)

    def get_IK(self, matrix, custom_guess=None, joint_num=6):
        ''' Inverse kinematics

        Runs ``mr.IKinSpace`` from several initial guesses and returns the
        first converged solution whose joints all lie inside the joint
        limits. A converged solution that violates a limit is discarded and
        the next guess is tried.

        matrix: desired 4x4 end-effector matrix
        custom_guess: optional initial guess, tried before the built-in ones
        joint_num: number of leading joints to solve for; the first
                   ``joint_num`` columns of ``Slist`` and entries of each
                   guess are used
        return: list of joint angles if a solution exists, otherwise None
        '''
        default_guesses = [[0.0] * 6 for _ in range(3)]
        default_guesses[1][0] = math.radians(-120)
        default_guesses[2][0] = math.radians(120)
        if custom_guess is None:
            initial_guesses = default_guesses
        else:
            initial_guesses = [list(custom_guess)] + default_guesses

        full_turn = math.pi * 2
        for guess_index, guess in enumerate(initial_guesses):
            theta_list, success = mr.IKinSpace(
                self.Slist[:, :joint_num], self.M, matrix,
                guess[:joint_num], 0.001, 0.001)
            if not success:
                continue

            within_limits = True
            for joint, theta in enumerate(theta_list):
                lower = self.lower_joint_limits[joint]
                upper = self.upper_joint_limits[joint]
                # Reduce angles of a full turn or more to less than one turn.
                if theta <= -full_turn:
                    theta %= -full_turn
                elif theta >= full_turn:
                    theta %= full_turn
                # Try the equivalent angle one turn away when outside the limits.
                if round(theta, 3) < round(lower, 3):
                    theta += full_turn
                elif round(theta, 3) > round(upper, 3):
                    theta -= full_turn
                theta_list[joint] = theta
                if not (lower <= theta <= upper):
                    within_limits = False
                    break

            if within_limits:
                return theta_list.tolist()
            if custom_guess is not None and guess_index == 0:
                rospy.logwarn(
                    "custom_guess faild, using default guess.")

        return None

    def joint_status_callback(self, msg):
        ''' joint states callback

        Stores the first six reported positions as the real joint angles.
        '''
        # Slicing and list() already produce an independent list.
        positions = list(msg.position[:6])
        with self.joint_mutex:
            self._real_joint_degree = positions

    def set_torque(self, enable):
        ''' Set the torque output

        enable: True publishes "open", False publishes "off"
        return: always True
        '''
        command = "open" if enable else "off"
        self.torque_pub.publish(std_msgs.msg.String(command))
        return True

    def set_gripper(self, line):
        ''' Set the gripper distance

        line: gripper distance, in m
        return: always True
        '''
        self.gripper_pub.publish(std_msgs.msg.Float64(line))
        return True

    def get_single_joint(self, index):
        ''' Return the last commanded angle of a single joint

        index: joint number, 1-based (joint 1 is list element 0). Note that
               set_single_joint() takes a 0-based index instead; both
               conventions are kept for existing callers.

        return: joint angle, in radians
        '''
        return self._joint_degree[index - 1]

    def set_single_joint(self, index, degree):
        ''' Set the angle of a single joint

        The command is built on a copy of the last commanded angles, so a
        rejected value does not alter the remembered command.

        index: 0-based position of the joint in the joint list
        degree: target joint angle
        return: result of set_all_joint(); True on success, False on failure
        '''
        target = list(self._joint_degree)
        target[index] = degree
        return self.set_all_joint(target)

    def set_all_joint(self, degrees):
        ''' Set all joint angles

        degrees: list of joint angles

        return: True on success, False if any angle is outside its limits
                (nothing is published in that case)
        '''
        for joint, angle in enumerate(degrees):
            if not (self.lower_joint_limits[joint] <= angle <= self.upper_joint_limits[joint]):
                return False
        self.joint_pub.publish(
            sdk_sagittarius_arm.msg.ArmRadControl(rad=degrees))
        # Keep a private copy so later changes to the caller's list cannot
        # silently modify the remembered command.
        self._joint_degree = list(degrees)
        return True

    def get_ee_posture(self):
        ''' Get the end-effector pose computed from the reported joint angles

        return: 4x4 end-effector matrix
        '''
        return self.get_FK(self._real_joint_degree)

    def get_ee_posture_command(self):
        ''' Get the end-effector pose of the last accepted command

        return: 4x4 end-effector matrix
        '''
        return self.get_FK(self._joint_degree)

    def set_ee_posture(self, matrix):
        ''' Set the end-effector pose

        Solves IK seeded with the reported joint angles and commands the
        result.

        matrix: 4x4 end-effector matrix
        return: True on success, False on failure
        '''
        theta = self.get_IK(matrix, self._real_joint_degree)
        if theta is not None and self.set_all_joint(theta):
            return True

        # Failure-log gate: log once, then re-arm on a later failed call.
        # The elapsed-time threshold is 0 s, so the gate re-arms as soon as
        # the clock has advanced.
        if not self._time_start:
            rospy.loginfo("The posture can not found or out of range.")
            self._time_last = rospy.Time.now().to_sec()
            self._time_start = True
        elif rospy.Time.now().to_sec() - self._time_last > 0:
            self._time_start = False
        return False


class SagittariusJoyDemo:
    ''' Example of controlling the arm with a gamepad

    Consumes ``arm_joy`` messages from ``commands/joy_processed`` and turns
    them into arm commands. The button layout below is the one listed in the
    original notes; the mapping from buttons to message fields is not part
    of this class.

    Controls

    Left stick forward/back: end-effector x increase/decrease
    Left stick left/right: waist rotation, swing left/right
    Left stick press: none

    Right stick forward/back: end-effector pitch, look down/look up
    Right stick left/right: end-effector y increase/decrease
    Right stick press: special setting, straighten the gripper

    LB(L1): decrease height
    RB(R1): increase height

    BACK(SHARE): roll counter-clockwise
    START(OPTION): roll clockwise

    LT(L2): decrease gripper spacing
    RT(R2): increase gripper spacing

    D-pad up/down: speed increase/decrease
    D-pad up/down: none
        NOTE(review): the original notes list up/down twice; presumably this
        line refers to D-pad left/right - confirm.

    B(circle): preset pose home
    A(cross): preset pose sleep

    Hold power for 1 second: toggle torque output
    '''

    # Bounds and increment of the step scale adjusted in joy_control_cb().
    _MIN_STEP_SCALE = 0.1
    _MAX_STEP_SCALE = 4.0
    _STEP_SCALE_INCREMENT = 0.1

    def __init__(self):
        self.waist_step = 0.01  # step of joint 1 (waist) per control cycle
        self.rotate_step = 0.01  # step applied to roll/pitch
        self.translate_step = 0.002  # step applied to x/y/z
        self.gripper_pressure_step = 0.001  # step of the gripper command
        self.current_step_scale = 2  # multiplier applied to the steps
        self.current_torque_status = True  # torque output flag
        self.current_gripper_spacing = 0  # default gripper command
        self.joy_msg = arm_joy()
        self.joy_mutex = threading.Lock()
        self.rate = rospy.Rate(15)
        self.armbot = SagittariusCtrlTool()
        self.T_sy = np.identity(4)  # yaw
        self.T_yb = np.identity(4)
        self.update_T_yb()
        self.joy_sub = rospy.Subscriber("commands/joy_processed",
                                        arm_joy, self.joy_control_cb)

    def update_speed(self, step_scale):
        ''' Update the speed

        The speed is changed by updating the step multiplier.
        step_scale: step multiplier, default 2
        '''
        self.current_step_scale = step_scale
        rospy.loginfo("current step scale is %.2f." % self.current_step_scale)

    def update_T_yb(self):
        ''' Refresh the stored end-effector pose

        Splits the commanded pose T_sb into a yaw-only transform T_sy and
        the remainder T_yb = inv(T_sy) * T_sb; the two are multiplied back
        together before running IK.
        '''
        T_sb = self.armbot.get_ee_posture_command()
        rpy = tf_transformations.euler_from_matrix(T_sb[:3, :3])
        self.T_sy[:2, :2] = AngleManipulation.yawToRotationMatrix(rpy[2])
        self.T_yb = np.dot(AngleManipulation.transInv(self.T_sy), T_sb)

    def update_gripper_spacing(self, gripper_spacing):
        ''' Set the gripper distance

        gripper_spacing: gripper distance, in m
        '''
        self.current_gripper_spacing = gripper_spacing
        self.armbot.set_gripper(self.current_gripper_spacing)
        rospy.loginfo("Gripper spacing is at %.2f." %
                      (self.current_gripper_spacing))

    def joy_control_cb(self, msg):
        ''' Callback receiving the processed gamepad message

        Stores a copy of the message for controller() and handles the speed
        and torque commands immediately.
        '''
        with self.joy_mutex:
            self.joy_msg = copy.deepcopy(msg)

        # Check the speed_cmd. The new scale is rounded to one decimal so
        # accumulated floating-point error cannot slip past the bounds (for
        # example ending at ~0 instead of stopping at 0.1).
        if (msg.speed_cmd == arm_joy.SPEED_INC
                and self.current_step_scale < self._MAX_STEP_SCALE):
            faster = round(self.current_step_scale + self._STEP_SCALE_INCREMENT, 1)
            self.update_speed(min(self._MAX_STEP_SCALE, faster))
        elif (msg.speed_cmd == arm_joy.SPEED_DEC
                and self.current_step_scale > self._MIN_STEP_SCALE):
            slower = round(self.current_step_scale - self._STEP_SCALE_INCREMENT, 1)
            self.update_speed(max(self._MIN_STEP_SCALE, slower))

        # Check the torque_cmd
        if (msg.torque_cmd == arm_joy.TORQUE_ON):
            self.armbot.set_torque(True)
            self.current_torque_status = True
        elif (msg.torque_cmd == arm_joy.TORQUE_OFF):
            self.armbot.set_torque(False)
            self.current_torque_status = False

    # @brief Main control loop to manipulate the arm
    def controller(self):
        ''' Control the arm

        Called once per cycle from the main loop. Applies the most recently
        received gamepad message; does nothing while torque is off.
        '''
        if not self.current_torque_status:
            return

        step_scale = self.current_step_scale
        waist_step = self.waist_step * step_scale
        translate_step = self.translate_step * step_scale
        rotate_step = self.rotate_step * step_scale

        with self.joy_mutex:
            msg = copy.deepcopy(self.joy_msg)

        # Check the gripper_cmd. INC lowers the commanded value (down to
        # about -0.03) and DEC raises it (up to about 0).
        # NOTE(review): the sign convention of the gripper command is not
        # visible in this file - confirm against the gripper driver.
        if (msg.gripper_spacing_cmd == arm_joy.GRIPPER_SPACING_INC
                and self.current_gripper_spacing > -0.03):
            self.update_gripper_spacing(
                self.current_gripper_spacing - self.gripper_pressure_step)
        elif (msg.gripper_spacing_cmd == arm_joy.GRIPPER_SPACING_DEC
                and self.current_gripper_spacing < 0):
            self.update_gripper_spacing(
                self.current_gripper_spacing + self.gripper_pressure_step)

        # Check the pose_cmd: preset poses take priority over everything below.
        if (msg.pose_cmd != 0):
            if (msg.pose_cmd == arm_joy.HOME_POSE):
                self.armbot.set_all_joint([0, 0, 0, 0, 0, 0])
            elif (msg.pose_cmd == arm_joy.SLEEP_POSE):
                self.armbot.set_all_joint([0, 1.4, -1.48, 0, -0.4851, 0])
            elif (msg.pose_cmd == arm_joy.UP_POSE):
                pass
            self.update_T_yb()
            return

        # Check the waist_cmd
        if (msg.ee_yaw_cmd != 0):
            waist_lower = self.armbot.lower_joint_limits[0]
            waist_upper = self.armbot.upper_joint_limits[0]
            waist_position = self.armbot.get_single_joint(1)
            if (msg.ee_yaw_cmd == arm_joy.EE_YAW_LEFT):
                # Decreasing angle: stop at the lower limit.
                waist_position = min(
                    max(waist_position - waist_step, waist_lower), waist_upper)
                self.armbot.set_single_joint(0, waist_position)
            elif (msg.ee_yaw_cmd == arm_joy.EE_YAW_RIGHT):
                # Increasing angle: stop at the upper limit.
                waist_position = max(
                    min(waist_position + waist_step, waist_upper), waist_lower)
                self.armbot.set_single_joint(0, waist_position)
            self.update_T_yb()
            return

        # Check the position and orientation cmd
        position_changed = msg.ee_x_cmd + msg.ee_y_cmd + msg.ee_z_cmd
        orientation_changed = msg.ee_roll_cmd + msg.ee_pitch_cmd + msg.ee_yaw_cmd

        if (position_changed + orientation_changed == 0
                and msg.reset_cmd != arm_joy.ORIENTATION_RESET
                and msg.reset_cmd != arm_joy.POSITION_RESET):
            return

        # Work on a copy so a failed IK leaves the stored T_yb untouched.
        T_yb = np.array(self.T_yb)

        # Has the position changed
        if (position_changed):
            # check ee_x_cmd
            if (msg.ee_x_cmd == arm_joy.EE_X_INC):
                T_yb[0, 3] += translate_step
            elif (msg.ee_x_cmd == arm_joy.EE_X_DEC):
                T_yb[0, 3] -= translate_step

            # check ee_y_cmd: y is only moved while x is beyond 0.3
            if (msg.ee_y_cmd == arm_joy.EE_Y_INC and T_yb[0, 3] > 0.3):
                T_yb[1, 3] += translate_step
            elif (msg.ee_y_cmd == arm_joy.EE_Y_DEC and T_yb[0, 3] > 0.3):
                T_yb[1, 3] -= translate_step

            # check ee_z_cmd
            if (msg.ee_z_cmd == arm_joy.EE_Z_INC):
                T_yb[2, 3] += translate_step
            elif (msg.ee_z_cmd == arm_joy.EE_Z_DEC):
                T_yb[2, 3] -= translate_step

        # Does the position need a reset (no action is taken for it here)
        elif (msg.reset_cmd == arm_joy.POSITION_RESET):
            return

        # Has the orientation changed
        if (orientation_changed != 0):
            roll, pitch, yaw = tf_transformations.euler_from_matrix(
                T_yb[:3, :3])

            # check ee_roll_cmd
            if (msg.ee_roll_cmd == arm_joy.EE_ROLL_CW):
                roll += rotate_step
            elif (msg.ee_roll_cmd == arm_joy.EE_ROLL_CCW):
                roll -= rotate_step

            # check ee_pitch_cmd
            if (msg.ee_pitch_cmd == arm_joy.EE_PITCH_DOWN):
                pitch += rotate_step
            elif (msg.ee_pitch_cmd == arm_joy.EE_PITCH_UP):
                pitch -= rotate_step

            # Yaw is not adjusted here: a non-zero ee_yaw_cmd is handled by
            # the waist branch above.
            T_yb[:3, :3] = tf_transformations.euler_matrix(roll, pitch, yaw)[
                :3, :3]

        # Does the orientation need a reset: zero y, roll and yaw, keep pitch
        elif (msg.reset_cmd == arm_joy.ORIENTATION_RESET):
            T_yb[1, 3] = 0
            roll, pitch, yaw = tf_transformations.euler_from_matrix(
                T_yb[:3, :3])
            T_yb[:3, :3] = tf_transformations.euler_matrix(0, pitch, 0)[
                :3, :3]

        # Get desired transformation matrix of the end-effector w.r.t. the base frame
        T_sd = np.dot(self.T_sy, T_yb)
        success = self.armbot.set_ee_posture(T_sd)
        if (success):
            self.T_yb = np.array(T_yb)


if __name__ == '__main__':
    # Entry point: start the node, print the control help, then run the
    # control loop at the rate configured on the demo object.
    rospy.init_node('sagittarius_joy_demo')
    try:
        arm = SagittariusJoyDemo()
        rospy.sleep(5)

        # Control help; an empty string prints a blank log line between groups.
        for usage_line in (
            "***********  控制方式  **************",
            "左摇杆前/后: end-effect x 加/减",
            "左摇杆左/右：腰部旋转 左摆/右摆",
            "左摇杆摁下: 无",
            "",
            "右摇杆前/后: end-effect pitch 往下看/往上看",
            "右摇杆左/右: end-effect y 加/减",
            "右摇杆摁下: 特殊设置 摆正夹爪",
            "",
            "LB(L1): 高度减少",
            "RB(R1): 高度增加",
            "",
            "BACK(SHARE): 逆时针翻滚",
            "START(OPTION): 顺时针翻滚",
            "",
            "LT(L2): 夹爪减少间距",
            "RT(R2): 夹爪增加间距",
            "",
            "方向键上/下: 速度增加/减少",
            "方向键上/下: 无",
            "",
            "B(cycle): 预设姿态 home",
            "A(cross): 预设姿态 sleep",
            "",
            "长按 power 1秒: 扭矩输出切换",
            "*********************************",
        ):
            rospy.loginfo(usage_line)

        while not rospy.is_shutdown():
            arm.controller()
            arm.rate.sleep()
    except rospy.ROSInterruptException:
        # rospy sleeps raise this when the node is shut down while waiting;
        # treat it as a normal exit instead of printing a traceback.
        pass
